import os

import PIL.Image
from mindspore.dataset import transforms
from mindspore.dataset import vision
import mindspore.dataset as ds
from mindspore import context
from dataset.dataset import get_cc12m_dataset
from mindspore import Tensor
import mindspore
import numpy as np

if __name__ == '__main__':
    # Sanity-check script: build a CC12M dataset from the MindRecord files and
    # print how many samples its iterator yields.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    # The 50 shard files: ./mindrecord/cc12m.mindrecord00 ... cc12m.mindrecord49.
    data_dir = ['./mindrecord/cc12m.mindrecord{:02d}'.format(shard) for shard in range(50)]

    # The earlier version constructed a dataset for every entry of `data_dir`
    # in a loop, but the counting loop ran only once, after that loop, on the
    # last dataset it left bound. The other 49 constructions were discarded
    # work, so build that one dataset directly; the printed count is unchanged.
    #
    # NOTE(review): whether this count covers all shards or only shard 49
    # depends on how get_cc12m_dataset opens the path (not visible here). If it
    # forwards the string to mindspore.dataset.MindDataset, sibling shard files
    # of the same source are loaded automatically and this is the full total.
    # If it reads only the given file, loop over `data_dir` and sum the
    # per-shard counts instead -- confirm before relying on the number.
    #
    # Only 'file_name' is requested (presumably the column list), which keeps
    # iteration light; the meaning of the remaining positional arguments
    # (1, False, 1, 0, 1) is defined by get_cc12m_dataset.
    dataset = get_cc12m_dataset(data_dir[-1], ['file_name'], 1, False, 1, 0, 1)

    count = sum(1 for _ in dataset.create_dict_iterator(output_numpy=False))
    print("Got {} samples".format(count))

